-- Optimized branchless binary search over sorted vectors -*- lua -*-
--
-- An optimized implementation of branchless binary search, following
-- the article by Paul Khuong, "Binary search *eliminates* branch
-- mispredictions":
--
--   http://www.pvk.ca/Blog/2012/07/03/binary-search-star-eliminates-star-branch-mispredictions/

-- Declare this file as a module; with package.seeall the module
-- environment can still see the standard globals.
module(..., package.seeall)

-- When true, assemble() prints a dump of every routine it generates.
-- Note that this local shadows the standard `debug` library for the
-- rest of this file.
local debug = false

local ffi = require("ffi")
local bit = require("bit")
local C = ffi.C

local dasm = require("dasm")

-- DynASM directives: select the x64 target and name the action list
-- ("actions") that assemble() passes to dasm.new().
|.arch x64
|.actionlist actions

-- Table keeping machine code alive to the GC.
local anchor = {}

-- Utility: assemble code and optionally dump disassembly.
-- Utility: drive `generator` against a fresh DynASM state, build the
-- resulting machine code and return it cast to the FFI function type
-- `prototype`.  `name` is only used to label the optional debug dump.
local function assemble (name, prototype, generator)
   local state = dasm.new(actions)
   generator(state)
   local code, code_size = state:build()
   -- Hold a reference to the code buffer so that it outlives this call.
   anchor[#anchor + 1] = code
   if debug then
      print("mcode dump: "..name)
      dasm.dump(code, code_size)
   end
   return ffi.cast(prototype, code)
end

-- Generate a branchless binary search routine specialized for vectors
-- of exactly `count` entries of type `entry_type`.
--
-- Parameters:
--   count       number of entries in the vectors that will be searched;
--               must be a positive integer.
--   entry_type  FFI ctype of a single entry.  The generated code compares
--               the 32-bit value at the start of each entry with the key
--               (presumably a leading uint32_t key field -- confirm
--               against callers).
--
-- Returns an FFI function of type `entry_type* (*)(entry_type*, uint32_t)`.
-- Given a vector sorted by key and a key to look for, it returns a
-- pointer to the first entry whose key is not below (unsigned compare)
-- the searched key, or to the last entry if every key is smaller.  The
-- caller still has to check whether the entry found actually matches.
--
-- Raises an error if `count` is not a positive integer: the offsets
-- computed below would otherwise be fractional or negative and the
-- generated code would read outside of the vector.
function gen(count, entry_type)
   if type(count) ~= "number" or count < 1 or count % 1 ~= 0 then
      error("binary search: count must be a positive integer, got "
               ..tostring(count), 2)
   end

   local function gen_binary_search(Dst)
      -- A single-entry vector needs no comparison at all: the only
      -- entry is the answer.
      if count == 1 then
         | mov rax, rdi
         | ret
         return
      end

      local entry_byte_size = ffi.sizeof(entry_type)
      -- Smallest power of two that is >= count.
      local size = 1
      while size < count do size = size * 2 end

      -- On entry the vector pointer is in rdi and the key we are
      -- looking for is in esi.  rdi is narrowed down in place; rax is
      -- used as scratch for the candidate upper half.

      -- In the first bisection, make sure the rest of the bisections
      -- have a power-of-two size: the upper window is anchored to the
      -- end of the vector, so it may overlap the lower one.
      do
         local next_size = size / 2
         local mid = next_size - 1
         local mid_offset = mid * entry_byte_size
         local hi_offset = (count - next_size) * entry_byte_size
         | cmp [rdi + mid_offset], esi
         | lea rax, [rdi + hi_offset]
         -- Move to the upper window iff entry[mid] < key (unsigned).
         | cmovb rdi, rax
         size = next_size
      end

      -- In the rest, just burn down the halves.  Wheeee!
      while size > 1 do
         local next_size = size / 2
         local mid = next_size - 1
         local mid_offset = mid * entry_byte_size
         local hi_offset = next_size * entry_byte_size
         | cmp [rdi + mid_offset], esi
         | lea rax, [rdi + hi_offset]
         | cmovb rdi, rax
         size = next_size
      end

      -- Now rdi points at the answer (if we have one).  Done!
      | mov rax, rdi
      | ret
   end
   return assemble("binary_search_"..count,
                   ffi.typeof("$*(*)($*, uint32_t)", entry_type, entry_type),
                   gen_binary_search)
end

-- Self-test: generate searchers for vector sizes 1 through 10 and check
-- the index each one returns for a handful of keys in a fixed sorted
-- vector.  Raises an error describing the first mismatch.
function selftest ()
   print("selftest: binary_search")
   local haystack = ffi.new('uint32_t[15]',
                            { 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5 })
   local max_size = 10
   local search_fns = {}
   for n = 1, max_size do
      search_fns[n] = gen(n, ffi.typeof('uint32_t'))
   end

   -- Search the first `n` entries for `key` and compare the index of
   -- the returned entry with `want`.
   local function check(n, key, want)
      local got = search_fns[n](haystack, key) - haystack
      if got == want then return end
      error(('in search of size %d for key %d: expected %d, got %d'):format(
               n, key, want, got))
   end

   -- Keys at or below the smallest entry land on index 0; a key above
   -- every entry lands on the last entry of the searched prefix.
   for n = 1, max_size do
      check(n, 0, 0)
      check(n, 1, 0)
      check(n, 6, n - 1)
   end

   -- Each row: key, expected index, smallest prefix size to test.
   local runs = {
      { 2, 1, 2 },
      { 3, 3, 4 },
      { 4, 6, 7 },
   }
   for _, run in ipairs(runs) do
      local key, want, first_size = run[1], run[2], run[3]
      for n = first_size, max_size do
         check(n, key, want)
      end
   end

   print("selftest: ok")
end
